import torch

# Demo: torch.chunk vs torch.split on a random 3x4 tensor.
# torch.rand samples uniformly from [0, 1), so the printed values differ on
# every run; only the shapes below are stable.
a = torch.rand(3, 4)
print(a)

# torch.chunk(input, chunks, dim) asks for a NUMBER of pieces. Each piece has
# ceil(size / chunks) rows along `dim`, and the last piece gets whatever is
# left over: 3 rows split into 2 chunks -> shapes (2, 4) and (1, 4).
# Kept under its own name so the result is not silently overwritten below.
chunks = torch.chunk(a, 2, dim=0)

# torch.split(tensor, split_size_or_sections, dim) with a LIST gives the exact
# size of each piece; the sizes must sum to a.shape[0] (here 1 + 2 == 3),
# so the pieces have shapes (1, 4) and (2, 4).
out = torch.split(a, [1, 2], dim=0)
print(out[0], out[0].shape)
print(out[1], out[1].shape)